package com.chatplus.application.aiprocessor.handler;

import com.chatplus.application.aiprocessor.constant.AIConstants;
import com.chatplus.application.aiprocessor.platform.chat.ChatProcessorService;
import com.chatplus.application.aiprocessor.provider.ChatAiProcessorServiceProvider;
import com.chatplus.application.aiprocessor.util.ChatWebSocketUtil;
import com.chatplus.application.aiprocessor.util.WebSocketManager;
import com.chatplus.application.common.constant.GroupCacheNames;
import com.chatplus.application.common.domain.response.SensitiveWordFilterResultResponse;
import com.chatplus.application.common.enumeration.UserStatusEnum;
import com.chatplus.application.common.logging.SouthernQuietLogger;
import com.chatplus.application.common.logging.SouthernQuietLoggerFactory;
import com.chatplus.application.common.util.CacheGroupUtils;
import com.chatplus.application.common.util.spring.SpringCtxUtils;
import com.chatplus.application.domain.dto.AdminConfigDto;
import com.chatplus.application.domain.dto.RoleContextDto;
import com.chatplus.application.domain.dto.ws.ChatRecordMessage;
import com.chatplus.application.domain.dto.ws.WsChatMessage;
import com.chatplus.application.domain.entity.account.UserEntity;
import com.chatplus.application.domain.entity.chat.ChatHistoryEntity;
import com.chatplus.application.domain.entity.chat.ChatModelEntity;
import com.chatplus.application.domain.entity.chat.ChatRoleEntity;
import com.chatplus.application.domain.request.ws.ChatWebSocketRequest;
import com.chatplus.application.enumeration.AiPlatformEnum;
import com.chatplus.application.enumeration.MessageTypeEnum;
import com.chatplus.application.service.account.UserProductLogService;
import com.chatplus.application.service.account.UserService;
import com.chatplus.application.service.basedata.SensitiveWordService;
import com.chatplus.application.service.chat.ChatHistoryService;
import com.chatplus.application.service.chat.ChatModelService;
import com.chatplus.application.service.chat.ChatRoleService;
import com.chatplus.application.util.ConfigUtil;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;
import org.redisson.api.RedissonClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.AbstractWebSocketHandler;

import java.io.IOException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * AI chat WebSocket session handler, built on Spring WebSocket.
 *
 * <p>This {@code @Component} declares no explicit scope, so Spring creates a single instance that
 * receives the callbacks of every connection. Per-connection data (chat id, session id, model,
 * platform, ...) is therefore never kept in instance fields: every callback resolves the
 * connection's {@link ChatWebSocketRequest} from the session attributes instead.
 */
@Component
public class ChatWebSocketHandler extends AbstractWebSocketHandler {
    private static final SouthernQuietLogger LOGGER = SouthernQuietLoggerFactory.getLogger(ChatWebSocketHandler.class);
    /**
     * Maximum number of context messages (role context plus history) sent along with the current
     * prompt; the oldest ones are dropped first.
     */
    private static final int MAX_CONTEXT_SIZE = 10;
    /**
     * Total number of online connections; only accessed through the static synchronized methods
     * below, which lock on the class.
     */
    private static int onlineCount;

    /**
     * AI platform of each counted connection, keyed by chat session id. A key is added (and
     * {@link #onlineCount} incremented) once per session id; a reconnect that reuses the same
     * session id only replaces the value.
     */
    private static final Map<String, AiPlatformEnum> platformMap = new ConcurrentHashMap<>();
    private final ChatModelService chatModelService;
    private final ChatRoleService chatRoleService;
    private final ChatHistoryService chatHistoryService;
    private final UserService userService;
    private final SensitiveWordService sensitiveWordService;
    private final RedissonClient redissonClient;
    private final ObjectMapper objectMapper;
    private final UserProductLogService userProductLogService;

    @Autowired
    public ChatWebSocketHandler(ChatModelService chatModelService,
                                ChatRoleService chatRoleService,
                                ChatHistoryService chatHistoryService,
                                UserService userService,
                                SensitiveWordService sensitiveWordService,
                                RedissonClient redissonClient,
                                ObjectMapper objectMapper,
                                UserProductLogService userProductLogService) {
        this.chatModelService = chatModelService;
        this.chatRoleService = chatRoleService;
        this.chatHistoryService = chatHistoryService;
        this.userService = userService;
        this.sensitiveWordService = sensitiveWordService;
        this.redissonClient = redissonClient;
        this.objectMapper = objectMapper;
        this.userProductLogService = userProductLogService;
    }

    /**
     * Returns the current number of online connections.
     */
    public static synchronized int getOnlineCount() {
        return onlineCount;
    }

    /**
     * Increments the online connection count by one.
     */
    public static synchronized void addOnlineCount() {
        ChatWebSocketHandler.onlineCount++;
    }

    /**
     * Decrements the online connection count by one.
     */
    public static synchronized void subOnlineCount() {
        ChatWebSocketHandler.onlineCount--;
    }

    /**
     * Validates and registers a newly opened connection.
     *
     * <p>If the connection request is missing, or the model/role is unavailable, the session is
     * closed and neither registered nor counted.
     */
    @Override
    public void afterConnectionEstablished(@NotNull WebSocketSession session) {
        ChatWebSocketRequest chatWebSocketRequest = getChatRequest(session);
        if (chatWebSocketRequest == null) {
            LOGGER.message("缺少连接请求参数，关闭连接").context("连接信息", session.getAttributes()).error();
            optClose(session);
            return;
        }
        String sessionId = chatWebSocketRequest.getSessionId();
        // NOTE(review): initHandler replies before the session is registered in WebSocketManager;
        // confirm ChatWebSocketUtil can still reach it, otherwise the client only sees the close.
        AiPlatformEnum platform = initHandler(chatWebSocketRequest);
        if (platform == null) {
            // Close this very session: it is not registered yet, so a lookup by id would miss it
            // (or hit a stale session that happens to share the id).
            optClose(session);
            return;
        }
        WebSocketManager.add(sessionId, session);
        // Atomic put: a reconnect that reuses the session id replaces the entry without being
        // counted twice.
        if (platformMap.put(sessionId, platform) == null) {
            addOnlineCount();
        }
        LOGGER.message("建立连接").context("连接ID", chatWebSocketRequest.getChatId())
                .context("当前连接数", getOnlineCount()).info();
    }

    /**
     * Handles a message sent by the client.
     *
     * <p>Only text frames are processed. The payload is parsed as a {@link WsChatMessage}; text that
     * fails to parse is treated as a plain chat prompt. Heartbeats are ignored. Chat messages are
     * validated, enriched with context and handed to the AI processor for streaming.
     */
    @Override
    public void handleMessage(@NotNull WebSocketSession session, @NotNull WebSocketMessage<?> message) {
        // Binary/pong payloads are ByteBuffers whose toString() must not become a chat prompt.
        if (!(message instanceof TextMessage) || message.getPayloadLength() == 0) {
            return;
        }
        ChatWebSocketRequest chatWebSocketRequest = getChatRequest(session);
        if (chatWebSocketRequest == null) {
            LOGGER.message("缺少连接请求参数，忽略消息").context("连接信息", session.getAttributes()).error();
            return;
        }
        String sessionId = chatWebSocketRequest.getSessionId();
        String payload = ((TextMessage) message).getPayload();
        WsChatMessage wsChatMessage;
        try {
            wsChatMessage = objectMapper.readValue(payload, WsChatMessage.class);
        } catch (Exception e) {
            // Not a JSON envelope: the raw text is the prompt.
            wsChatMessage = new WsChatMessage(MessageTypeEnum.CHAT, payload);
        }
        try {
            // Heartbeat messages carry no prompt.
            if (MessageTypeEnum.HEARTBEAT.equals(wsChatMessage.getType())) {
                return;
            }
            LOGGER.message("收到消息").context("连接ID", chatWebSocketRequest.getChatId())
                    .context("参数列表", session.getAttributes())
                    .context("消息", payload).info();
            chatWebSocketRequest.setCurrentPrompt(wsChatMessage.getContent());
            // ------------------------ user checks ------------------------
            String verifyResult = verifyUserRequest(chatWebSocketRequest);
            if (StringUtils.isNotEmpty(verifyResult)) {
                ChatWebSocketUtil.replyErrorMessage(sessionId, verifyResult);
                return;
            }
            // ------------------------ conversation context ------------------------
            // Stop when the role became unusable; the session has already been closed then.
            if (!handleChatContext(session, chatWebSocketRequest)) {
                return;
            }
            // ------------------------ processing ------------------------
            ChatAiProcessorServiceProvider chatAiProcessorServiceProvider = SpringCtxUtils.getBean(ChatAiProcessorServiceProvider.class);
            ChatProcessorService processorService = chatAiProcessorServiceProvider.getAiProcessorServiceByRequest(chatWebSocketRequest);
            processorService.processStream();
        } catch (Exception e) {
            ChatWebSocketUtil.replyErrorMessage(sessionId);
            LOGGER.message("收到消息处理异常").exception(e).error();
        }
    }

    /**
     * Handles a transport error: logs it, notifies the client and closes/unregisters the
     * connection.
     */
    @Override
    public void handleTransportError(WebSocketSession session, @NotNull Throwable exception) throws IOException {
        LOGGER.message("连接发送错误").context("连接信息", session.getAttributes()).context("exception", exception.getMessage()).error();
        ChatWebSocketRequest chatWebSocketRequest = getChatRequest(session);
        if (chatWebSocketRequest == null) {
            optClose(session);
            return;
        }
        String sessionId = chatWebSocketRequest.getSessionId();
        ChatWebSocketUtil.replyErrorMessage(sessionId, "连接发送错误,等待管理员修复！");
        WebSocketManager.removeAndClose(sessionId);
    }

    /**
     * Unregisters a closed connection and decrements the online count if it had been counted.
     */
    @Override
    public void afterConnectionClosed(@NotNull WebSocketSession session, @NotNull CloseStatus closeStatus) throws IOException {
        ChatWebSocketRequest chatWebSocketRequest = getChatRequest(session);
        if (chatWebSocketRequest == null) {
            // Never registered or counted: nothing to clean up.
            return;
        }
        String sessionId = chatWebSocketRequest.getSessionId();
        // Atomic remove: only the caller that actually removed the entry decrements the count.
        if (platformMap.remove(sessionId) != null) {
            subOnlineCount();
        }
        WebSocketManager.removeAndClose(sessionId);
        LOGGER.message("断开连接").context("连接ID", chatWebSocketRequest.getChatId()).context("当前连接数", getOnlineCount()).info();
    }

    /**
     * Validates the user request:
     * 1. sensitive-word filtering of the current prompt (when enabled in the system config)
     * 2. user existence and status
     * 3. remaining compute power against the model's per-call cost
     *
     * @param request the connection request carrying the current prompt
     * @return an error message to send to the client, or {@code null} when the request may proceed
     */
    private String verifyUserRequest(ChatWebSocketRequest request) {
        // Sensitive-word filtering
        AdminConfigDto adminConfigDto = CacheGroupUtils.get(GroupCacheNames.SYS_AI_SETTING, AIConstants.SYSTEM_CONFIG_REDIS_KEY);
        if (adminConfigDto == null) {
            return "加载系统配置失败，连接已关闭！！！";
        }
        if (adminConfigDto.isEnabledSensitiveWord()) {
            SensitiveWordFilterResultResponse wordFilterResultResponse = sensitiveWordService.filter(request.getCurrentPrompt());
            if (wordFilterResultResponse.isContainsSensitiveWord()) {
                LOGGER.message("用户输入包含敏感词").context("input", request.getCurrentPrompt())
                        .context("敏感词", wordFilterResultResponse.getSensitiveWordList()).info();
                return "您的输入包含敏感词，请重新输入！如果触发误报，请联系管理员！";
            }
        }
        // User checks
        Long userId = request.getUserId();
        UserEntity userEntity = userService.getById(userId);
        if (userEntity == null) {
            return "非法用户，请联系管理员！";
        }
        if (userEntity.getStatus() != UserStatusEnum.OK) {
            return "您的账号已经被禁用，如果疑问，请联系管理员！";
        }
        // Balance checks
        Integer power = request.getPower();
        int chatCalls = userProductLogService.getUserChatPower(userId);
        if (chatCalls <= 0) {
            return "您的对话次数已经用尽，请联系管理员或者充值点卡继续对话！";
        }
        // A model without a configured cost is only subject to the "> 0" check above.
        if (power != null && chatCalls < power) {
            return String.format("您当前剩余算力（%d）已不足以支付当前模型的单次对话需要消耗的算力（%d）！"
                    , chatCalls, power);
        }
        return null;
    }

    /**
     * Closes the given session with {@link CloseStatus#NORMAL} if it is still open; failures are
     * logged, not propagated.
     *
     * @param session the connection to close, may be {@code null}
     */
    private static void optClose(WebSocketSession session) {
        if (session == null || !session.isOpen()) {
            return;
        }
        try {
            session.close(CloseStatus.NORMAL);
        } catch (Exception e) {
            LOGGER.message("关闭session异常").exception(e).error();
        }
    }

    /**
     * Returns the {@link ChatWebSocketRequest} stored in the session attributes, or {@code null}
     * when it is absent or of another type.
     */
    private static ChatWebSocketRequest getChatRequest(WebSocketSession session) {
        Object attribute = session.getAttributes().get(ChatWebSocketRequest.WS_URL_PATH);
        return attribute instanceof ChatWebSocketRequest ? (ChatWebSocketRequest) attribute : null;
    }

    /**
     * Builds the conversation context for the current prompt and stores it on the request:
     * 1. role context (hello message and predefined user/assistant turns)
     * 2. chat history (when enabled in the admin config)
     * 3. the current prompt itself, always last
     *
     * <p>At most {@link #MAX_CONTEXT_SIZE} messages from steps 1 and 2 are kept (the most recent
     * ones).
     *
     * @param session the connection, closed when the role is unavailable
     * @param request the connection request; its chat context list is replaced
     * @return {@code false} if the role is missing or disabled (an error was sent and the session
     *         closed), {@code true} otherwise
     */
    private boolean handleChatContext(WebSocketSession session, ChatWebSocketRequest request) {
        ChatRoleEntity chatRoleEntity = chatRoleService.getById(request.getRoleId());
        if (chatRoleEntity == null || !Boolean.TRUE.equals(chatRoleEntity.getEnable())) {
            ChatWebSocketUtil.replyErrorMessage(request.getSessionId(), "当前聊天角色不存在或者未启用，连接已关闭！！！");
            optClose(session);
            return false;
        }
        List<ChatRecordMessage> recordMessageList = buildRoleContext(chatRoleEntity);
        recordMessageList.addAll(buildHistoryContext(request));
        if (recordMessageList.size() > MAX_CONTEXT_SIZE) {
            // Copy instead of keeping a subList view over the larger backing list.
            recordMessageList = new ArrayList<>(recordMessageList.subList(
                    recordMessageList.size() - MAX_CONTEXT_SIZE, recordMessageList.size()));
        }
        recordMessageList.add(ChatRecordMessage.builder().prompt(request.getCurrentPrompt()).build());
        request.setChatContextList(recordMessageList);
        return true;
    }

    /**
     * Converts a role's hello message and its predefined "user"/"assistant" turns into context
     * messages; entries with empty content or another role are skipped.
     */
    private static List<ChatRecordMessage> buildRoleContext(ChatRoleEntity chatRoleEntity) {
        List<ChatRecordMessage> messages = new ArrayList<>();
        if (StringUtils.isNotEmpty(chatRoleEntity.getHelloMsg())) {
            messages.add(ChatRecordMessage.builder().system(chatRoleEntity.getHelloMsg()).build());
        }
        List<RoleContextDto> roleContext = chatRoleEntity.getContext();
        if (CollectionUtils.isEmpty(roleContext)) {
            return messages;
        }
        for (RoleContextDto roleContextDto : roleContext) {
            String content = roleContextDto.getContent();
            if (StringUtils.isEmpty(content)) {
                continue;
            }
            if ("user".equals(roleContextDto.getRole())) {
                messages.add(ChatRecordMessage.builder().prompt(content).build());
            } else if ("assistant".equals(roleContextDto.getRole())) {
                messages.add(ChatRecordMessage.builder().reply(content).build());
            }
        }
        return messages;
    }

    /**
     * Builds prompt/reply pairs from the chat history, ordered by reply id.
     *
     * <p>Only prompts among the first {@code contextDeep * 2} history records are indexed (by item
     * id); a reply whose prompt is not indexed, or where either side is empty, is skipped.
     * NOTE(review): the original comment says function-call replies are excluded from the context;
     * that filtering is not visible here (presumably done by getChatContextHistoryList) — confirm.
     *
     * @return a mutable list, empty when context is disabled or there is no history
     */
    private List<ChatRecordMessage> buildHistoryContext(ChatWebSocketRequest request) {
        AdminConfigDto adminConfigDto = ConfigUtil.getAdminConfig();
        if (adminConfigDto == null || !adminConfigDto.isEnableContext() || adminConfigDto.getContextDeep() <= 0) {
            return new ArrayList<>();
        }
        List<ChatHistoryEntity> historyList = chatHistoryService.getChatContextHistoryList(request.getUserId(), request.getChatId());
        if (CollectionUtils.isEmpty(historyList)) {
            return new ArrayList<>();
        }
        // Sorted by creation time so that, for duplicate item ids, the latest prompt wins.
        // HashMap (unlike ConcurrentHashMap) tolerates a prompt with null content.
        Map<Long, String> promptMap = historyList.stream()
                .limit(adminConfigDto.getContextDeep() * 2L)
                .filter(history -> "prompt".equals(history.getType()))
                .sorted(Comparator.comparing(ChatHistoryEntity::getCreatedAt))
                .collect(HashMap::new,
                        (m, v) -> m.put(v.getItemId(), v.getContent()),
                        HashMap::putAll);
        return historyList.stream()
                .sorted(Comparator.comparing(ChatHistoryEntity::getId))
                .filter(history -> "reply".equals(history.getType()))
                .map(history -> {
                    String prompt = promptMap.get(history.getItemId());
                    String reply = history.getContent();
                    if (StringUtils.isEmpty(prompt) || StringUtils.isEmpty(reply)) {
                        return null;
                    }
                    return ChatRecordMessage.builder().prompt(prompt).reply(reply).build();
                })
                .filter(Objects::nonNull)
                .collect(ArrayList::new, ArrayList::add, ArrayList::addAll);
    }

    /**
     * Stops the stream currently being sent to the client of the given session.
     *
     * @param sessionId chat session id
     */
    public void stop(String sessionId) {
        WebSocketManager.stopSeeClient(sessionId);
    }

    /**
     * Validates the model and role of a new connection and fills the model-derived fields (power,
     * model, platform, channel) on the request.
     *
     * @param request the connection request
     * @return the model's AI platform, or {@code null} when the model (or its platform) or the
     *         role is unavailable; an error reply has been sent in that case
     */
    private AiPlatformEnum initHandler(ChatWebSocketRequest request) {
        String sessionId = request.getSessionId();
        // 1. The model must exist, be enabled and have a platform to route to.
        ChatModelEntity chatModelEntity = chatModelService.getById(request.getModelId());
        if (chatModelEntity == null || !Boolean.TRUE.equals(chatModelEntity.getEnabled())
                || chatModelEntity.getPlatform() == null) {
            ChatWebSocketUtil.replyErrorMessage(sessionId, "当前AI模型暂未启用，连接已关闭！！！");
            return null;
        }
        AiPlatformEnum platform = chatModelEntity.getPlatform();
        request.setPower(chatModelEntity.getPower());
        request.setChatModel(chatModelEntity);
        request.setPlatform(platform);
        // Fall back to the platform value when no explicit channel is configured.
        request.setChannel(StringUtils.isNotEmpty(chatModelEntity.getChannel()) ? chatModelEntity.getChannel() : platform.getValue());
        // 2. The role must exist and be enabled.
        ChatRoleEntity chatRoleEntity = chatRoleService.getById(request.getRoleId());
        if (chatRoleEntity == null || !Boolean.TRUE.equals(chatRoleEntity.getEnable())) {
            ChatWebSocketUtil.replyErrorMessage(sessionId, "当前聊天角色不存在或者未启用，连接已关闭！！！");
            return null;
        }
        return platform;
    }
}
